Source file: https://shirinsplayground.netlify.app/2021/03/kmeans_101/
Repeat steps 3 and 4 until no points need to be reassigned.
Stop. You have found your k clusters and their centers!
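The steps above can be sketched as a short base-R function. This is an illustration only (no handling of empty clusters, single random start); in practice use the built-in `kmeans()`, as we do later in this post.

```r
# A bare-bones k-means sketch in base R, for illustration only.
# x: numeric matrix of points; k: number of clusters.
simple_kmeans <- function(x, k, max_iter = 100) {
  centers <- x[sample(nrow(x), k), , drop = FALSE]  # random initial centers
  assignment <- integer(nrow(x))
  for (i in seq_len(max_iter)) {
    # Step 3: assign each point to its nearest center (Euclidean distance)
    d <- as.matrix(dist(rbind(centers, x)))[-(1:k), 1:k]
    new_assignment <- max.col(-d)
    if (all(new_assignment == assignment)) break  # nothing moved: done
    assignment <- new_assignment
    # Step 4: move each center to the mean of its assigned points
    for (j in seq_len(k)) {
      centers[j, ] <- colMeans(x[assignment == j, , drop = FALSE])
    }
  }
  list(cluster = assignment, centers = centers)
}
```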
If you want to learn more about k-Means, I would recommend this post on Medium, though be aware that the example code is all written in Python. If you are brave and want to go very deep into k-Means theory, take a look at the Wikipedia page. Or, if you would like to see one application of k-Means in R, see this blog’s post about using k-Means to assist in image classification with Keras. For a detailed illustration of how to implement k-Means in R, along with answers to some common questions, keep reading below.
Setting up two functions:
# Define two functions for transforming a distribution of values
# into the standard normal distribution (bell curve with mean = 0
# and standard deviation (sd) = 1). More on this later.
normalize_values <- function(x, mean, sd) {
  (x - mean) / sd
}

unnormalize_values <- function(x, mean, sd) {
  (x * sd) + mean
}
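A quick sanity check that the two helpers are inverses of each other, using a toy vector chosen so the arithmetic is easy to follow:

```r
# The two helpers from above, repeated so this snippet runs standalone
normalize_values <- function(x, mean, sd) (x - mean) / sd
unnormalize_values <- function(x, mean, sd) (x * sd) + mean

x <- c(10, 20, 30)
m <- mean(x)  # 20
s <- sd(x)    # 10
z <- normalize_values(x, m, s)
z                            # -1  0  1
unnormalize_values(z, m, s)  # 10 20 30: round-trips back to the original
```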
set.seed(2021) # So you can reproduce this example
The data we will use for this example is from one of R’s pre-loaded datasets, quakes. It is a data.frame with 1000 rows and five columns describing earthquakes near Fiji since 1964. The columns are latitude (degrees), longitude (degrees), depth (km), magnitude (Richter scale), and the number of stations reporting the quake. The only pre-processing we will do now is to remove stations and convert this to a tibble.
quakes_raw <- quakes %>%
  dplyr::select(-stations) %>%
  dplyr::as_tibble()
summary(quakes_raw)
## lat long depth mag
## Min. :-38.59 Min. :165.7 Min. : 40.0 Min. :4.00
## 1st Qu.:-23.47 1st Qu.:179.6 1st Qu.: 99.0 1st Qu.:4.30
## Median :-20.30 Median :181.4 Median :247.0 Median :4.60
## Mean :-20.64 Mean :179.5 Mean :311.4 Mean :4.62
## 3rd Qu.:-17.64 3rd Qu.:183.2 3rd Qu.:543.0 3rd Qu.:4.90
## Max. :-10.72 Max. :188.1 Max. :680.0 Max. :6.40
k-Means calculates the distance to the cluster center using Euclidean distance: the length of the line segment connecting two points. In two dimensions, this is just the Pythagorean Theorem. Aha, you say! I see the problem: we are comparing magnitudes (4.0-6.4) to depths (40-680). Depth has far more variation (standard deviation 215 for depth vs. 0.4 for magnitude) and therefore gets overweighted when calculating the distance to the cluster center.
We need to employ feature scaling. As a general rule, if we are comparing unlike units (meters and kilograms) or independent measurements (height in meters and circumference in meters), we should normalize values, but if units are related (petal length and petal width), we should leave them as is.
Unfortunately, many cases require judgment both on whether to scale and how to scale. This is where your expert opinion as a data analyst becomes important. For the purposes of this blog post, we will normalize all of our features, including latitude and longitude, by transforming them to standard normal distributions. The geologists might object to this methodology for normalizing (magnitude is a log scale!!), but please forgive some imprecision for the sake of illustration.
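To make the overweighting concrete, here is a quick sketch comparing two pairs of hypothetical quakes: one pair differs only by 100 km in depth, the other only by a full point of magnitude.

```r
# Identical quakes except depth differs by 100 km
a <- c(lat = -20, long = 180, depth = 100, mag = 4.5)
b <- c(lat = -20, long = 180, depth = 200, mag = 4.5)
# Identical quakes except magnitude differs by 1.0
p <- c(lat = -20, long = 180, depth = 100, mag = 4.5)
q <- c(lat = -20, long = 180, depth = 100, mag = 5.5)
dist(rbind(a, b))  # 100: depth differences dominate
dist(rbind(p, q))  # 1: a huge magnitude gap barely registers
```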
# Create a tibble to store the information we need to normalize
# Tibble with row 1 = mean and row 2 = standard deviation
transformations <- dplyr::tibble(
  lat = c(mean(quakes_raw$lat), sd(quakes_raw$lat)),
  long = c(mean(quakes_raw$long), sd(quakes_raw$long)),
  depth = c(mean(quakes_raw$depth), sd(quakes_raw$depth)),
  mag = c(mean(quakes_raw$mag), sd(quakes_raw$mag))
)
# Use the convenient function we wrote earlier
quakes_normalized <- quakes_raw %>%
  dplyr::mutate(
    lat = normalize_values(
      lat, transformations$lat[1], transformations$lat[2]
    ),
    long = normalize_values(
      long, transformations$long[1], transformations$long[2]
    ),
    depth = normalize_values(
      depth, transformations$depth[1], transformations$depth[2]
    ),
    mag = normalize_values(
      mag, transformations$mag[1], transformations$mag[2]
    )
  )
summary(quakes_normalized)
## lat long depth mag
## Min. :-3.56890 Min. :-2.27235 Min. :-1.2591 Min. :-1.54032
## 1st Qu.:-0.56221 1st Qu.: 0.02603 1st Qu.:-0.9853 1st Qu.:-0.79548
## Median : 0.06816 Median : 0.32095 Median :-0.2987 Median :-0.05065
## Mean : 0.00000 Mean : 0.00000 Mean : 0.0000 Mean : 0.00000
## 3rd Qu.: 0.59761 3rd Qu.: 0.61586 3rd Qu.: 1.0747 3rd Qu.: 0.69419
## Max. : 1.97319 Max. : 1.42812 Max. : 1.7103 Max. : 4.41837
kclust <- kmeans(quakes_normalized, centers = 4, iter.max = 10, nstart = 5)
str(kclust)
## List of 9
## $ cluster : int [1:1000] 1 1 2 1 1 3 4 2 2 1 ...
## $ centers : num [1:4, 1:4] -0.012 -1.736 0.294 0.934 0.222 ...
## ..- attr(*, "dimnames")=List of 2
## .. ..$ : chr [1:4] "1" "2" "3" "4"
## .. ..$ : chr [1:4] "lat" "long" "depth" "mag"
## $ totss : num 3996
## $ withinss : num [1:4] 594 253 340 358
## $ tot.withinss: num 1546
## $ betweenss : num 2450
## $ size : int [1:4] 420 143 242 195
## $ iter : int 4
## $ ifault : int 0
## - attr(*, "class")= chr "kmeans"
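Why `centers = 4`? That choice is a judgment call. One common sanity check is the elbow heuristic: fit the model for a range of k and look for the point where the total within-cluster sum of squares stops dropping sharply. A sketch (using base R's `scale()`, which standardizes columns the same way as our `normalize_values` helper):

```r
set.seed(2021)
# Standardize the same four columns we cluster on
quakes_scaled <- scale(quakes[, c("lat", "long", "depth", "mag")])
# Total within-cluster sum of squares for k = 1..8
wss <- sapply(1:8, function(k) {
  kmeans(quakes_scaled, centers = k, iter.max = 10, nstart = 5)$tot.withinss
})
plot(1:8, wss, type = "b",
     xlab = "Number of clusters k",
     ylab = "Total within-cluster sum of squares")
```

Look for the "elbow" in the resulting curve; the gain from adding more clusters flattens out around there.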
Print the full model output:
kclust
## K-means clustering with 4 clusters of sizes 420, 143, 242, 195
##
## Cluster means:
## lat long depth mag
## 1 -0.01202836 0.2224322 1.0714971 -0.25340992
## 2 -1.73556818 0.3865266 -0.8379212 0.32263747
## 3 0.29398939 0.8951082 -0.7466850 -0.07834941
## 4 0.93380886 -1.8733897 -0.7667092 0.40643880
##
## Clustering vector:
## [1] 1 1 2 1 1 3 4 2 2 1 1 4 1 1 4 3 4 1 1 1 1 4 1 2 1 1 4 1 1 3 1 4 3 1 3 1 4
## [38] 1 1 4 2 3 1 3 4 2 2 4 1 3 1 3 4 1 1 1 1 1 1 1 1 3 1 4 1 3 1 1 1 3 3 3 4 1
## [75] 1 1 3 4 1 2 2 1 1 3 1 3 4 1 3 3 4 4 1 4 3 1 2 3 4 1 3 1 1 2 1 3 2 4 2 2 3
## [112] 1 1 1 1 1 4 4 4 4 4 3 1 1 2 4 1 1 2 3 3 1 4 1 1 4 3 3 2 3 4 1 4 2 1 3 3 4
## [149] 1 1 2 4 3 4 4 1 4 1 4 4 1 1 4 2 2 2 2 3 1 4 1 1 1 3 1 2 1 3 1 3 1 3 3 1 1
## [186] 3 1 1 1 3 3 4 1 1 2 1 3 3 1 1 1 1 3 2 4 3 1 1 2 1 2 3 2 3 1 1 2 1 3 1 3 4
## [223] 3 1 1 4 3 3 2 4 1 1 1 3 1 1 2 1 4 1 3 3 4 3 3 1 1 3 1 4 4 4 1 4 2 1 3 4 1
## [260] 1 3 1 4 1 3 2 4 4 1 1 3 1 1 1 1 1 1 1 1 1 2 1 3 1 3 1 1 3 1 1 1 4 1 1 3 3
## [297] 1 1 2 4 1 3 3 1 1 1 1 1 1 2 4 4 1 1 3 1 1 4 1 4 4 3 1 2 4 2 1 4 3 4 3 1 2
## [334] 4 1 2 2 2 2 2 2 2 2 2 2 2 2 2 3 2 2 4 1 2 2 1 4 3 2 4 1 3 1 2 4 2 1 3 1 3
## [371] 1 3 1 1 3 3 1 3 3 3 4 4 2 4 1 2 1 1 4 2 3 2 1 1 1 1 1 3 1 4 2 4 2 1 3 1 1
## [408] 4 1 2 3 3 4 1 3 4 4 2 2 3 1 2 3 1 2 2 1 1 4 1 1 1 1 1 3 2 4 1 1 3 4 3 4 1
## [445] 3 3 3 1 1 1 2 1 4 3 3 4 1 3 1 1 3 1 1 1 2 1 4 3 1 2 3 3 3 4 3 2 2 2 1 3 1
## [482] 3 3 2 1 2 2 1 1 1 2 4 1 1 3 4 3 2 1 3 3 3 3 4 1 1 4 4 4 1 2 3 3 1 1 3 4 3
## [519] 3 1 3 1 2 4 2 1 4 4 1 2 4 4 3 4 3 4 1 4 4 1 4 4 4 4 4 4 4 1 3 1 1 4 4 3 3
## [556] 3 3 2 1 4 1 3 2 3 1 3 2 1 3 2 4 3 1 3 1 3 1 1 1 2 4 1 4 1 1 3 4 1 1 1 1 1
## [593] 4 4 3 4 4 1 3 3 2 2 1 1 1 2 3 3 1 2 2 4 4 1 3 1 3 4 1 4 2 2 1 1 4 1 2 1 4
## [630] 1 1 2 2 1 3 1 3 4 3 1 1 4 2 3 1 3 2 3 2 1 1 1 4 1 1 4 1 2 1 1 2 1 1 1 1 1
## [667] 1 3 1 3 1 4 1 1 4 2 3 1 1 1 4 3 1 1 3 4 1 1 1 1 1 1 1 1 1 1 1 3 4 1 4 1 1
## [704] 1 3 3 1 3 3 1 3 4 3 4 1 1 1 2 3 1 4 1 1 1 4 4 1 1 1 1 1 2 2 1 4 3 3 1 1 1
## [741] 1 3 3 2 3 2 2 3 1 1 1 3 1 4 1 4 3 3 4 3 3 1 3 3 4 4 3 3 4 3 3 1 3 1 1 1 3
## [778] 1 4 1 4 4 2 3 3 2 4 1 3 1 1 1 1 2 1 2 4 1 3 1 2 1 3 1 1 3 1 1 1 1 1 3 4 3
## [815] 1 1 3 3 1 1 1 3 1 1 4 1 4 3 3 3 2 3 3 1 1 3 4 1 3 4 1 3 2 4 1 1 1 1 1 1 3
## [852] 4 4 3 1 3 1 3 2 4 1 1 1 1 4 4 1 3 4 4 1 4 4 1 1 1 1 1 1 1 3 1 4 4 2 1 4 4
## [889] 3 2 4 1 4 3 1 1 3 1 3 1 2 3 2 1 1 1 1 4 4 2 1 4 3 1 4 4 3 3 3 1 4 3 1 1 2
## [926] 1 4 2 2 4 3 1 1 3 3 3 1 3 1 1 1 2 1 2 3 2 1 2 3 3 1 2 3 1 1 1 2 3 1 3 3 4
## [963] 4 3 3 1 1 1 4 2 2 2 3 4 3 3 3 2 1 4 1 4 2 2 3 4 4 4 1 3 1 1 3 1 3 1 4 3 3
## [1000] 4
##
## Within cluster sum of squares by cluster:
## [1] 594.2867 252.9391 340.4248 358.4645
## (between_SS / total_SS = 61.3 %)
##
## Available components:
##
## [1] "cluster" "centers" "totss" "withinss" "tot.withinss"
## [6] "betweenss" "size" "iter" "ifault"
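The between_SS / total_SS figure on the last line is worth a second look: it is the fraction of the total variance that lies between clusters, so higher means tighter clusters. It can be recomputed directly from the printed components:

```r
# between_SS / total_SS, recomputed from the values printed above
# (in a live session, kclust$betweenss / kclust$totss gives the same ratio)
round(100 * 2450 / 3996, 1)  # 61.3
```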
point_assignments <- broom::augment(kclust, quakes_normalized) %>%
  dplyr::mutate(
    lat = unnormalize_values(
      lat, transformations$lat[1], transformations$lat[2]
    ),
    long = unnormalize_values(
      long, transformations$long[1], transformations$long[2]
    ),
    depth = unnormalize_values(
      depth, transformations$depth[1], transformations$depth[2]
    ),
    mag = unnormalize_values(
      mag, transformations$mag[1], transformations$mag[2]
    )
  )
cluster_info <- broom::tidy(kclust) %>%
  dplyr::mutate(
    lat = unnormalize_values(
      lat, transformations$lat[1], transformations$lat[2]
    ),
    long = unnormalize_values(
      long, transformations$long[1], transformations$long[2]
    ),
    depth = unnormalize_values(
      depth, transformations$depth[1], transformations$depth[2]
    ),
    mag = unnormalize_values(
      mag, transformations$mag[1], transformations$mag[2]
    )
  )
model_stats <- broom::glance(kclust)
Print the cluster assignments:
head(point_assignments)
## # A tibble: 6 x 5
## lat long depth mag .cluster
## <dbl> <dbl> <dbl> <dbl> <fct>
## 1 -20.4 182. 562 4.8 1
## 2 -20.6 181. 650 4.2 1
## 3 -26 184. 42 5.4 2
## 4 -18.0 182. 626 4.1 1
## 5 -20.4 182. 649 4 1
## 6 -19.7 184. 195 4 3
Model statistics:
model_stats
## # A tibble: 1 x 4
## totss tot.withinss betweenss iter
## <dbl> <dbl> <dbl> <int>
## 1 3996 1546. 2450. 4
Cluster information:
cluster_info
## # A tibble: 4 x 7
## lat long depth mag size withinss cluster
## <dbl> <dbl> <dbl> <dbl> <int> <dbl> <fct>
## 1 -20.7 181. 542. 4.52 420 594. 1
## 2 -29.4 182. 131. 4.75 143 253. 2
## 3 -19.2 185. 150. 4.59 242 340. 3
## 4 -15.9 168. 146. 4.78 195 358. 4
Plot the data with clusters:
plotly::plot_ly() %>%
  plotly::add_trace(
    data = point_assignments,
    x = ~long, y = ~lat, z = ~depth * -1, size = ~mag,
    color = ~.cluster,
    type = "scatter3d", mode = "markers",
    marker = list(symbol = "circle", sizemode = "diameter"),
    sizes = c(5, 30)
  ) %>%
  plotly::layout(scene = list(
    xaxis = list(title = "Longitude"),
    yaxis = list(title = "Latitude"),
    zaxis = list(title = "Depth")
  ))
## Warning: `line.width` does not currently support multiple values.